import numpy as np
import torch
from torch import nn
from torch.nn.utils import weight_norm

from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d


class Wavegrad(nn.Module):
    # pylint: disable=dangerous-default-value
    def __init__(self,
                 in_channels=80,
                 out_channels=1,
                 use_weight_norm=False,
                 y_conv_channels=32,
                 x_conv_channels=768,
                 dblock_out_channels=[128, 128, 256, 512],
                 ublock_out_channels=[512, 512, 256, 128, 128],
                 upsample_factors=[5, 5, 3, 2, 2],
                 upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8],
                                     [1, 2, 4, 8], [1, 2, 4, 8]]):
        """Build the convolutions, DBlocks, FiLM modules and UBlocks of the model.

        Args:
            in_channels: input channels of ``x_conv``, the convolution that
                ``forward`` applies to the spectrogram.
            out_channels: output channels of the final ``out_conv``.
            use_weight_norm: when True, ``apply_weight_norm()`` is called once
                all layers have been created.
            y_conv_channels: output channels of ``y_conv`` (which takes a single
                input channel); also the input width of the first DBlock and
                the first FiLM module.
            x_conv_channels: output channels of ``x_conv`` and input width of
                the first UBlock.
            dblock_out_channels: output channels of the successive DBlocks.
            ublock_out_channels: output channels of the successive UBlocks; the
                same list, reversed, sizes the FiLM modules.
            upsample_factors: third argument of each UBlock and, in reverse
                order, of each DBlock (presumably the up/down-sampling factor).
                Their product is stored as ``hop_len``.
            upsample_dilations: fourth argument of each UBlock.

        The list defaults are only iterated in this constructor.
        """
        super().__init__()

        self.use_weight_norm = use_weight_norm
        self.hop_len = np.prod(upsample_factors)

        # Noise-schedule state; filled in later by compute_noise_level().
        self.noise_level = None
        self.num_steps = None
        self.beta = None
        self.alpha = None
        self.alpha_hat = None
        self.c1 = None
        self.c2 = None
        self.sigma = None

        # NOTE: sub-modules are created in the same order as before so that
        # parameter/state-dict ordering is unchanged.

        # Branch applied to the first argument of forward(): y_conv + DBlocks.
        self.y_conv = Conv1d(1, y_conv_channels, 5, padding=2)
        self.dblocks = nn.ModuleList([])
        in_ch = y_conv_channels
        for out_ch, factor in zip(dblock_out_channels, reversed(upsample_factors)):
            self.dblocks.append(DBlock(in_ch, out_ch, factor))
            in_ch = out_ch

        # One FiLM module for y_conv's output plus one per DBlock output.
        self.film = nn.ModuleList([])
        in_ch = y_conv_channels
        for out_ch in reversed(ublock_out_channels):
            self.film.append(FiLM(in_ch, out_ch))
            in_ch = out_ch

        # UBlock chain applied to the projected spectrogram.
        self.ublocks = nn.ModuleList([])
        in_ch = x_conv_channels
        for out_ch, factor, dilations in zip(ublock_out_channels, upsample_factors,
                                             upsample_dilations):
            self.ublocks.append(UBlock(in_ch, out_ch, factor, dilations))
            in_ch = out_ch

        self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1)
        # out_ch still holds the channel count of the last UBlock here.
        self.out_conv = Conv1d(out_ch, out_channels, 3, padding=1)

        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x, spectrogram, noise_scale):
        shift_and_scale = []

        x = self.y_conv(x)
        shift_and_scale.append(self.film[0](x, noise_scale))

        for film, layer in zip(self.film[1:], self.dblocks):
            x = layer(x)
            shift_and_scale.append(film(x, noise_scale))

        x = self.x_conv(spectrogram)
        for layer, (film_shift, film_scale) in zip(self.ublocks,
                                                   reversed(shift_and_scale)):
            x = layer(x, film_shift, film_scale)
        x = self.out_conv(x)
        return x

    def load_noise_schedule(self, path):
        beta = np.load(path, allow_pickle=True).item()['beta']
        self.compute_noise_level(beta)

    @torch.no_grad()
    def inference(self, x, y_n=None):
        """ x: B x D X T """
        if y_n is None:
            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
        else:
            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
        sqrt_alpha_hat = self.noise_level.to(x)
        for n in range(len(self.alpha) - 1, -1, -1):
            y_n = self.c1[n] * (y_n -
                        self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
            if n > 0:
                z = torch.randn_like(y_n)
                y_n += self.sigma[n - 1] * z
            y_n.clamp_(-1.0, 1.0)
        return y_n


    def compute_y_n(self, y_0):
        """Compute noisy audio based on noise schedule"""
        self.noise_level = self.noise_level.to(y_0)
        if len(y_0.shape) == 3:
            y_0 = y_0.squeeze(1)
        s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]])
        l_a, l_b = self.noise_level[s], self.noise_level[s+1]
        noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
        noise_scale = noise_scale.unsqueeze(1)
        noise = torch.randn_like(y_0)
        noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise
        return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]

    def compute_noise_level(self, beta):
        """Compute noise schedule parameters"""
        self.num_steps = len(beta)
        alpha = 1 - beta
        alpha_hat = np.cumprod(alpha)
        noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
        noise_level = alpha_hat ** 0.5

        # pylint: disable=not-callable
        self.beta = torch.tensor(beta.astype(np.float32))
        self.alpha = torch.tensor(alpha.astype(np.float32))
        self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
        self.noise_level = torch.tensor(noise_level.astype(np.float32))

        self.c1 = 1 / self.alpha**0.5
        self.c2 = (1 - self.alpha) / (1 - self.alpha_hat)**0.5
        self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:])**0.5

    def remove_weight_norm(self):
        for _, layer in enumerate(self.dblocks):
            if len(layer.state_dict()) != 0:
                try:
                    nn.utils.remove_weight_norm(layer)
                except ValueError:
                    layer.remove_weight_norm()

        for _, layer in enumerate(self.film):
            if len(layer.state_dict()) != 0:
                try:
                    nn.utils.remove_weight_norm(layer)
                except ValueError:
                    layer.remove_weight_norm()


        for _, layer in enumerate(self.ublocks):
            if len(layer.state_dict()) != 0:
                try:
                    nn.utils.remove_weight_norm(layer)
                except ValueError:
                    layer.remove_weight_norm()

        nn.utils.remove_weight_norm(self.x_conv)
        nn.utils.remove_weight_norm(self.out_conv)
        nn.utils.remove_weight_norm(self.y_conv)

    def apply_weight_norm(self):
        for _, layer in enumerate(self.dblocks):
            if len(layer.state_dict()) != 0:
                layer.apply_weight_norm()

        for _, layer in enumerate(self.film):
            if len(layer.state_dict()) != 0:
                layer.apply_weight_norm()


        for _, layer in enumerate(self.ublocks):
            if len(layer.state_dict()) != 0:
                layer.apply_weight_norm()

        self.x_conv = weight_norm(self.x_conv)
        self.out_conv = weight_norm(self.out_conv)
        self.y_conv = weight_norm(self.y_conv)


    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        if eval:
            self.eval()
            assert not self.training
            if self.use_weight_norm:
                self.remove_weight_norm()
            betas = np.linspace(config['test_noise_schedule']['min_val'],
                                config['test_noise_schedule']['max_val'],
                                config['test_noise_schedule']['num_steps'])
            self.compute_noise_level(betas)
        else:
            betas = np.linspace(config['train_noise_schedule']['min_val'],
                                config['train_noise_schedule']['max_val'],
                                config['train_noise_schedule']['num_steps'])
            self.compute_noise_level(betas)
